import typing
from typing import *
from torch_geometric.typing import *

import torch
from torch import Tensor
import torch_sparse
from torch_sparse import SparseTensor
from torch_geometric.nn.conv.message_passing import *
from {{module}} import *


# Typed record of the keyword arguments accepted by `propagate()`.
# One field is rendered per entry of `prop_types` (name -> type annotation).
class Propagate_{{uid}}(NamedTuple):
{%- for k, v in prop_types.items() %}
    {{k}}: {{v}}
{%- endfor %}

{% if edge_updater_types|length > 0 %}
# Typed record of the keyword arguments accepted by `edge_updater()`.
# Only rendered when `edge_updater_types` is non-empty; one field per entry.
class EdgeUpdater_{{uid}}(NamedTuple):
{%- for k, v in edge_updater_types.items() %}
    {{k}}: {{v}}
{%- endfor %}
{% endif %}

# Typed record returned by `__collect__()`.
# One field is rendered per entry of `collect_types` (name -> type annotation).
class Collect_{{uid}}(NamedTuple):
{%- for k, v in collect_types.items() %}
    {{k}}: {{v}}
{%- endfor %}

{% if edge_updater_types|length > 0 %}
# Typed record returned by `__collect_edge__()`.
# Fields come from `edge_collect_types`; the class itself is guarded by
# `edge_updater_types` being non-empty.
class EdgeUpdater_Collect_{{uid}}(NamedTuple):
{%- for k, v in edge_collect_types.items() %}
    {{k}}: {{v}}
{%- endfor %}
{% endif %}

# Generated subclass of the parent module (both names are supplied by the
# renderer). It re-declares the message-passing entry points with
# `torch.jit` overload stubs and `NamedTuple` argument records.
class {{cls_name}}({{parent_cls_name}}):

    # Overload stubs: `__check_input__` accepts either a `Tensor` or a
    # `SparseTensor` as `edge_index` and returns `List[Optional[int]]`.
    # The stub bodies are intentionally a bare `pass`.
    @torch.jit._overload_method
    def __check_input__(self, edge_index, size):
        # type: (Tensor, Size) -> List[Optional[int]]
        pass

    @torch.jit._overload_method
    def __check_input__(self, edge_index, size):
        # type: (SparseTensor, Size) -> List[Optional[int]]
        pass

    # The implementation is injected verbatim by the renderer via the
    # `check_input` variable (presumably the parent's method source -- confirm
    # against the code that renders this template).
{{check_input}}

    # Overload stubs: `__lift__` takes a source `Tensor`, the edge definition
    # (`Tensor` or `SparseTensor`) and an integer `dim`, and returns a `Tensor`.
    # The stub bodies are intentionally a bare `pass`.
    @torch.jit._overload_method
    def __lift__(self, src, edge_index, dim):
        # type: (Tensor, Tensor, int) -> Tensor
        pass

    @torch.jit._overload_method
    def __lift__(self, src, edge_index, dim):
        # type: (Tensor, SparseTensor, int) -> Tensor
        pass

    # The implementation is injected verbatim by the renderer via the `lift`
    # variable (presumably the parent's method source -- confirm against the
    # code that renders this template).
{{lift}}

    # Overload stubs: `edge_def` may be a `Tensor` or a `SparseTensor`; both
    # variants take the propagate argument record and return the collect record.
    @torch.jit._overload_method
    def __collect__(self, edge_def, size, kwargs):
        # type: (Tensor, List[Optional[int]], Propagate_{{uid}}) -> Collect_{{uid}}
        pass

    @torch.jit._overload_method
    def __collect__(self, edge_def, size, kwargs):
        # type: (SparseTensor, List[Optional[int]], Propagate_{{uid}}) -> Collect_{{uid}}
        pass

    # Builds the collect record whose fields `propagate()` forwards to
    # `message`, `aggregate` and `update`.
    #   edge_def: edge connectivity, either a `Tensor` or a `SparseTensor`.
    #   size:     list of optional ints, read at positions 0 and 1 and handed
    #             to `self.__set_size__` while arguments are gathered.
    #   kwargs:   the propagate argument record.
    def __collect__(self, edge_def, size, kwargs):
        # Placeholder tensor used to initialize locals whose rendered type is
        # not Optional.
        init = torch.tensor(0.)
        # i / j are the `edge_def` dimensions used for `_i` / `_j` arguments;
        # the pair is swapped when flow is not 'source_to_target'.
        i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)
        # One stanza is rendered per entry of `user_args`: plain names are read
        # straight from `kwargs`; names ending in `_i` / `_j` are taken from the
        # un-suffixed field (one element of it when it is a pair) and, when the
        # result is a Tensor, passed through `__lift__`.
        # NOTE(review): the pair element and the `__set_size__` index are the
        # literals 0/1, while `__lift__` receives the flow-dependent i/j --
        # confirm this is intended for flow='target_to_source'.
{% for arg in user_args %}
{%- if arg[-2:] not in ['_i', '_j'] %}
        {{arg}} = kwargs.{{arg}}
{%- else %}
        {{arg}}: {{collect_types[arg]}} = {% if collect_types[arg][:8] == 'Optional' %}None{% else %}init{% endif %}
        data = kwargs.{{arg[:-2]}}
        if isinstance(data, (tuple, list)):
            assert len(data) == 2
{%- if arg[-2:] == '_j' %}
            tmp = data[1]
            if isinstance(tmp, Tensor):
                self.__set_size__(size, 1, tmp)
            {{arg}} = data[0]
{%- else %}
            tmp = data[0]
            if isinstance(tmp, Tensor):
                self.__set_size__(size, 0, tmp)
            {{arg}} = data[1]
{%- endif %}
        else:
            {{arg}} = data
        if isinstance({{arg}}, Tensor):
            self.__set_size__(size, {% if arg[-2:] == '_j'%}0{% else %}1{% endif %}, {{arg}})
            {{arg}} = self.__lift__({{arg}}, edge_def, {% if arg[-2:] == "_j" %}j{% else %}i{% endif %})
{%- endif %}
{%- endfor %}

        # Connectivity-derived locals; which ones are set depends on the
        # runtime type of `edge_def`.
        edge_index: Optional[Tensor] = None
        adj_t: Optional[SparseTensor] = None
        edge_index_i: torch.Tensor = init
        edge_index_j: torch.Tensor = init
        ptr: Optional[Tensor] = None
        if isinstance(edge_def, Tensor):
            edge_index = edge_def
            edge_index_i = edge_def[i]
            edge_index_j = edge_def[j]
        elif isinstance(edge_def, SparseTensor):
            adj_t = edge_def
            edge_index_i = edge_def.storage.row()
            edge_index_j = edge_def.storage.col()
            ptr = edge_def.storage.rowptr()
            # Edge-level arguments that are collected but still None fall
            # back to the values stored in the sparse tensor.
            {% if 'edge_weight' in collect_types.keys() %}
            if edge_weight is None:
                edge_weight = edge_def.storage.value()
            {% endif %}
            {% if 'edge_attr' in collect_types.keys() %}
            if edge_attr is None:
                edge_attr = edge_def.storage.value()
            {% endif %}
            {% if 'edge_type' in collect_types.keys() %}
            if edge_type is None:
                edge_type = edge_def.storage.value()
            {% endif %}

        # Edge-level arguments whose rendered type is not Optional must be
        # populated at this point.
        {% if collect_types.get('edge_weight', 'Optional')[:8] != 'Optional' %}assert edge_weight is not None{% endif %}
        {% if collect_types.get('edge_attr', 'Optional')[:8] != 'Optional' %}assert edge_attr is not None{% endif %}
        {% if collect_types.get('edge_type', 'Optional')[:8] != 'Optional' %}assert edge_type is not None{% endif %}

        # Derived locals: `index` aliases `edge_index_i`; `size_i` / `size_j`
        # prefer size[1] / size[0] and fall back to the other entry when None.
        index = edge_index_i
        size_i = size[1] if size[1] is not None else size[0]
        size_j = size[0] if size[0] is not None else size[1]
        dim_size = size_i

        # Every key of `collect_types` must name a local defined above, since
        # each field is filled as `name=name`.
        return Collect_{{uid}}({% for k in collect_types.keys() %}{{k}}={{k}}{{ ", " if not loop.last }}{% endfor %})

{% if edge_updater_types|length > 0 %}
    # Overload stubs: `edge_def` may be a `Tensor` or a `SparseTensor`; both
    # variants take the edge-updater argument record and return the
    # edge-updater collect record.
    @torch.jit._overload_method
    def __collect_edge__(self, edge_def, size, kwargs):
        # type: (Tensor, List[Optional[int]], EdgeUpdater_{{uid}}) -> EdgeUpdater_Collect_{{uid}}
        pass

    @torch.jit._overload_method
    def __collect_edge__(self, edge_def, size, kwargs):
        # type: (SparseTensor, List[Optional[int]], EdgeUpdater_{{uid}}) -> EdgeUpdater_Collect_{{uid}}
        pass

    # Builds the collect record whose fields `edge_updater()` forwards to
    # `edge_update`. Same structure as `__collect__`, but driven by
    # `edge_user_args` and `edge_collect_types`.
    #   edge_def: edge connectivity, either a `Tensor` or a `SparseTensor`.
    #   size:     list of optional ints, read at positions 0 and 1 and handed
    #             to `self.__set_size__` while arguments are gathered.
    #   kwargs:   the edge-updater argument record.
    def __collect_edge__(self, edge_def, size, kwargs):
        # Placeholder tensor used to initialize locals whose rendered type is
        # not Optional.
        init = torch.tensor(0.)
        # i / j are the `edge_def` dimensions used for `_i` / `_j` arguments;
        # the pair is swapped when flow is not 'source_to_target'.
        i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)
        # One stanza is rendered per entry of `edge_user_args`: plain names are
        # read straight from `kwargs`; names ending in `_i` / `_j` are taken
        # from the un-suffixed field (one element of it when it is a pair) and,
        # when the result is a Tensor, passed through `__lift__`.
        # NOTE(review): the pair element and the `__set_size__` index are the
        # literals 0/1, while `__lift__` receives the flow-dependent i/j --
        # confirm this is intended for flow='target_to_source'.
{% for arg in edge_user_args %}
{%- if arg[-2:] not in ['_i', '_j'] %}
        {{arg}} = kwargs.{{arg}}
{%- else %}
        {{arg}}: {{edge_collect_types[arg]}} = {% if edge_collect_types[arg][:8] == 'Optional' %}None{% else %}init{% endif %}
        data = kwargs.{{arg[:-2]}}
        if isinstance(data, (tuple, list)):
            assert len(data) == 2
{%- if arg[-2:] == '_j' %}
            tmp = data[1]
            if isinstance(tmp, Tensor):
                self.__set_size__(size, 1, tmp)
            {{arg}} = data[0]
{%- else %}
            tmp = data[0]
            if isinstance(tmp, Tensor):
                self.__set_size__(size, 0, tmp)
            {{arg}} = data[1]
{%- endif %}
        else:
            {{arg}} = data
        if isinstance({{arg}}, Tensor):
            self.__set_size__(size, {% if arg[-2:] == '_j'%}0{% else %}1{% endif %}, {{arg}})
            {{arg}} = self.__lift__({{arg}}, edge_def, {% if arg[-2:] == "_j" %}j{% else %}i{% endif %})
{%- endif %}
{%- endfor %}

        # Connectivity-derived locals; which ones are set depends on the
        # runtime type of `edge_def`.
        edge_index: Optional[Tensor] = None
        adj_t: Optional[SparseTensor] = None
        edge_index_i: torch.Tensor = init
        edge_index_j: torch.Tensor = init
        ptr: Optional[Tensor] = None
        if isinstance(edge_def, Tensor):
            edge_index = edge_def
            edge_index_i = edge_def[i]
            edge_index_j = edge_def[j]
        elif isinstance(edge_def, SparseTensor):
            adj_t = edge_def
            edge_index_i = edge_def.storage.row()
            edge_index_j = edge_def.storage.col()
            ptr = edge_def.storage.rowptr()
            # Edge-level arguments that are collected but still None fall
            # back to the values stored in the sparse tensor.
            {% if 'edge_weight' in edge_collect_types.keys() %}
            if edge_weight is None:
                edge_weight = edge_def.storage.value()
            {% endif %}
            {% if 'edge_attr' in edge_collect_types.keys() %}
            if edge_attr is None:
                edge_attr = edge_def.storage.value()
            {% endif %}
            {% if 'edge_type' in edge_collect_types.keys() %}
            if edge_type is None:
                edge_type = edge_def.storage.value()
            {% endif %}

        # Edge-level arguments whose rendered type is not Optional must be
        # populated at this point.
        {% if edge_collect_types.get('edge_weight', 'Optional')[:8] != 'Optional' %}assert edge_weight is not None{% endif %}
        {% if edge_collect_types.get('edge_attr', 'Optional')[:8] != 'Optional' %}assert edge_attr is not None{% endif %}
        {% if edge_collect_types.get('edge_type', 'Optional')[:8] != 'Optional' %}assert edge_type is not None{% endif %}

        # Derived locals: `index` aliases `edge_index_i`; `size_i` / `size_j`
        # prefer size[1] / size[0] and fall back to the other entry when None.
        index = edge_index_i
        size_i = size[1] if size[1] is not None else size[0]
        size_j = size[0] if size[0] is not None else size[1]
        dim_size = size_i

        # Every key of `edge_collect_types` must name a local defined above,
        # since each field is filled as `name=name`.
        return EdgeUpdater_Collect_{{uid}}({% for k in edge_collect_types.keys() %}{{k}}={{k}}{{ ", " if not loop.last }}{% endfor %})


{% endif %}

    # Overload stubs: `edge_index` may be a `Tensor` or a `SparseTensor`. The
    # remaining parameters and their types are rendered from `prop_types`,
    # followed by an optional `size`.
    @torch.jit._overload_method
    def propagate(self, edge_index, {{ prop_types.keys()|join(', ') }}, size=None):
        # type: (Tensor, {{ prop_types.values()|join(', ') }}, Size) -> {{ prop_return_type }}
        pass

    @torch.jit._overload_method
    def propagate(self, edge_index, {{ prop_types.keys()|join(', ') }}, size=None):
        # type: (SparseTensor, {{ prop_types.values()|join(', ') }}, Size) -> {{ prop_return_type }}
        pass

    # Runs one round of message passing and returns the result of `update`.
    def propagate(self, edge_index, {{ prop_types.keys()|join(', ') }}, size=None):
        # Resolve `size` through `__check_input__` and pack the keyword
        # arguments into the propagate argument record.
        the_size = self.__check_input__(edge_index, size)
        in_kwargs = Propagate_{{uid}}({% for k in prop_types.keys() %}{{k}}={{k}}{{ ", " if not loop.last }}{% endfor %})

        {% if fuse and single_aggr %}
        # Fused path: for a `SparseTensor` input, call `message_and_aggregate`
        # directly. Arguments are read from the raw input record because
        # `__collect__` is skipped on this path.
        if isinstance(edge_index, SparseTensor):
            out = self.message_and_aggregate(edge_index{% for k in msg_and_aggr_args %}, {{k}}=in_kwargs.{{k}}{% endfor %})
            return self.update(out{% for k in update_args %}, {{k}}=in_kwargs.{{k}}{% endfor %})
        {% endif %}

        # General path: gather the collected arguments, then run
        # message -> aggregate -> update.
        kwargs = self.__collect__(edge_index, the_size, in_kwargs)
        out = self.message({% for k in msg_args %}{{k}}=kwargs.{{k}}{{ ", " if not loop.last }}{% endfor %})
        {% if single_aggr %}
        out = self.aggregate(out{% for k in aggr_args %}, {{k}}=kwargs.{{k}}{% endfor %})
        {% else %}
        # Multiple aggregations: aggregate once per entry of `self.aggrs` and
        # merge the results with `self.combine`.
        outs: List[Tensor] = []
        for aggr in self.aggrs:
            outs.append(self.aggregate(out{% for k in aggr_args %}, {{k}}=kwargs.{{k}}{% endfor %}, aggr=aggr))
        out = self.combine(outs)
        {% endif %}
        return self.update(out{% for k in update_args %}, {{k}}=kwargs.{{k}}{% endfor %})

{% if edge_updater_types|length > 0 %}
    # Overload stubs: `edge_index` may be a `Tensor` or a `SparseTensor`. The
    # remaining parameters and their types are rendered from
    # `edge_updater_types`.
    @torch.jit._overload_method
    def edge_updater(self, edge_index{% for k in edge_updater_types.keys() %}, {{k}} {% endfor %}):
        # type: (Tensor{% for v in edge_updater_types.values() %}, {{v}} {% endfor %}) -> {{ edge_updater_return_type }}
        pass

    @torch.jit._overload_method
    def edge_updater(self, edge_index{% for k in edge_updater_types.keys() %}, {{k}} {% endfor %}):
        # type: (SparseTensor{% for v in edge_updater_types.values() %}, {{v}} {% endfor %}) -> {{ edge_updater_return_type }}
        pass

    # Gathers the edge-level arguments and returns the result of `edge_update`.
    def edge_updater(self, edge_index{% for k in edge_updater_types.keys() %}, {{k}} {% endfor %}):
        # No `size` parameter is accepted here; `__check_input__` is always
        # called with size=None.
        the_size = self.__check_input__(edge_index, size=None)
        in_kwargs = EdgeUpdater_{{uid}}({% for k in edge_updater_types.keys() %}{{k}}={{k}}{{ ", " if not loop.last }}{% endfor %})
        kwargs = self.__collect_edge__(edge_index, the_size, in_kwargs)
        return self.edge_update({% for k in edge_update_args %}{{k}}=kwargs.{{k}}{{ ", " if not loop.last }}{% endfor %})
{% else %}
    # Placeholder rendered when `edge_updater_types` is empty: it takes no
    # arguments besides `self` and does nothing.
    def edge_updater(self):
        pass
{% endif %}

    # Read-only accessor for `self._explain`.
    @property
    def explain(self) -> bool:
        return self._explain

    # Assigning to `explain` always raises ValueError. Per the message,
    # explainability is only supported on the Python module.
    @explain.setter
    def explain(self, explain: bool):
        raise ValueError("Explainability of message passing modules "
                         "is only supported on the Python module")

{%- for (arg_types, return_type_repr) in forward_types %}

    # Overload stub, rendered once per entry of `forward_types`. It reuses the
    # shared forward signature with that entry's argument and return types.
    @torch.jit._overload_method
    {{forward_header}}
        # type: ({{arg_types|join(', ')}}) -> {{return_type_repr}}
        pass
{%- endfor %}

    # The forward implementation: signature and body are injected verbatim by
    # the renderer via `forward_header` and `forward_body` (presumably taken
    # from the parent module -- confirm against the rendering code).
    {{forward_header}}
{{forward_body}}
